package darecreek.vfuAutotest.alu

import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
import chisel3._
import chiseltest.WriteVcdAnnotation
import scala.reflect.io.File
import scala.reflect.runtime.universe._
import scala.util.control.Breaks._
import scala.util.Random

import darecreek.exu.vfu.perm._
import darecreek.exu.vfu._
import darecreek.exu.vfu.alu._
import darecreek.exu.vfu.VInstructions._
import chipsalliance.rocketchip.config.Parameters

/** Binds the shared `slidefsm` behavior to instruction id `vslideup_vx`, with "vslideup.vx.data" as `fn`. */
class VslideupvxFSMTestBehavior
    extends slidefsm(
        fn = "vslideup.vx.data",
        cb = ctrlBundles.vslideup_vx,
        s = "-",
        instid = "vslideup_vx"
    )
/** Binds the shared `slidefsm` behavior to instruction id `vslidedown_vx`, with "vslidedown.vx.data" as `fn`. */
class VslidedownvxFSMTestBehavior
    extends slidefsm(
        fn = "vslidedown.vx.data",
        cb = ctrlBundles.vslidedown_vx,
        s = "-",
        instid = "vslidedown_vx"
    )

/** Binds the shared `slidefsm` behavior to instruction id `vslide1up_vx`, with "vslide1up.vx.data" as `fn`. */
class Vslide1upvxFSMTestBehavior
    extends slidefsm(
        fn = "vslide1up.vx.data",
        cb = ctrlBundles.vslide1up_vx,
        s = "-",
        instid = "vslide1up_vx"
    )
/** Binds the shared `slidefsm` behavior to instruction id `vslide1down_vx`, with "vslide1down.vx.data" as `fn`. */
class Vslide1downvxFSMTestBehavior
    extends slidefsm(
        fn = "vslide1down.vx.data",
        cb = ctrlBundles.vslide1down_vx,
        s = "-",
        instid = "vslide1down_vx"
    )

/** Binds the shared `slidefsm` behavior to instruction id `vrgather_vv`, with "vrgather.vv.data" as `fn`. */
class VrgathervvFSMTestBehavior
    extends slidefsm(
        fn = "vrgather.vv.data",
        cb = ctrlBundles.vrgather_vv,
        s = "-",
        instid = "vrgather_vv"
    )
/** Binds the shared `slidefsm` behavior to instruction id `vrgather_vx`, with "vrgather.vx.data" as `fn`. */
class VrgathervxFSMTestBehavior
    extends slidefsm(
        fn = "vrgather.vx.data",
        cb = ctrlBundles.vrgather_vx,
        s = "-",
        instid = "vrgather_vx"
    )
/**
 * Binds the shared `slidefsm` behavior to instruction id `vrgatherei16_vv`, with
 * "vrgatherei16.vv.data" as `fn`. This instruction id is the one `slidefsm.testMultiple`
 * routes to `ei16FSMTestMultiple`.
 */
class Vrgatherei16vvFSMTestBehavior
    extends slidefsm(
        fn = "vrgatherei16.vv.data",
        cb = ctrlBundles.vrgatherei16_vv,
        s = "-",
        instid = "vrgatherei16_vv"
    )

/** Binds the shared `slidefsm` behavior to instruction id `vcompress_vm`, with "vcompress.vm.data" as `fn`. */
class VcompressvmFSMTestBehavior
    extends slidefsm(
        fn = "vcompress.vm.data",
        cb = ctrlBundles.vcompress_vm,
        s = "-",
        instid = "vcompress_vm"
    )


/** Binds the shared `slidefsm` behavior to instruction id `vslideup_vi`, with "vslideup.vi.data" as `fn`. */
class VslideupviFSMTestBehavior
    extends slidefsm(
        fn = "vslideup.vi.data",
        cb = ctrlBundles.vslideup_vi,
        s = "-",
        instid = "vslideup_vi"
    )
/** Binds the shared `slidefsm` behavior to instruction id `vslidedown_vi`, with "vslidedown.vi.data" as `fn`. */
class VslidedownviFSMTestBehavior
    extends slidefsm(
        fn = "vslidedown.vi.data",
        cb = ctrlBundles.vslidedown_vi,
        s = "-",
        instid = "vslidedown_vi"
    )
/** Binds the shared `slidefsm` behavior to instruction id `vrgather_vi`, with "vrgather.vi.data" as `fn`. */
class VrgatherviFSMTestBehavior
    extends slidefsm(
        fn = "vrgather.vi.data",
        cb = ctrlBundles.vrgather_vi,
        s = "-",
        instid = "vrgather_vi"
    )
/** Binds the shared `slidefsm` behavior to instruction id `vfslide1up_vf`, with "vfslide1up.vf.data" as `fn`. */
class Vfslide1upvfFSMTestBehavior
    extends slidefsm(
        fn = "vfslide1up.vf.data",
        cb = ctrlBundles.vfslide1up_vf,
        s = "-",
        instid = "vfslide1up_vf"
    )
/** Binds the shared `slidefsm` behavior to instruction id `vfslide1down_vf`, with "vfslide1down.vf.data" as `fn`. */
class Vfslide1downvfFSMTestBehavior
    extends slidefsm(
        fn = "vfslide1down.vf.data",
        cb = ctrlBundles.vfslide1down_vf,
        s = "-",
        instid = "vfslide1down_vf"
    )

/**
 * Single shared pseudo-random source for the tests in this file.
 * The seed is fixed (42), so the sequence of draws is the same on every run.
 */
object RandomGen {
    val rand : Random = new Random(42)
}

/**
 * Shared test behavior for instructions run on the `Permutation` module.
 *
 * A test case is issued to the DUT, the DUT's register read requests are answered from an
 * index-to-value map, write-back data is collected, and the collected values are compared
 * against the expected `VD` entries of the test case (unless a flush was injected).
 *
 * All four constructor arguments are forwarded unchanged to `TestBehavior`; `instid` is also
 * used by `testMultiple` to pick the vrgatherei16-specific flow.
 */
class slidefsm(fn : String, cb : CtrlBundle, s : String, instid : String) extends TestBehavior(fn, cb, s, instid) {
    
    // val rand = new scala.util.Random
    // Upper bound on how many times randomBool() may return true; `blocks` counts those
    // occurrences and is reset to 0 at the start of every testMultiple call.
    val maxBlocks = 20
    var blocks = 0
    // Base values for the register indices handed to the DUT. Index (base + j) is used as the
    // key under which the j-th vs1 / vs2 / old_vd value is stored in the index-to-value map.
    val vs1base = 100
    val vs2base = 110
    val oldvdbase = 120
    // Only used to label the failing write-back (fault_wb) when resComp dumps a mismatch.
    val vdbase = 130
    // Index under which the first mask value is stored in the index-to-value map.
    val maskbase = 140
    // Set by stageTwo from randomFlush() on every cycle; when true the redirect is poked and
    // the caller skips the result comparison. Reset to false in testMultiple.
    var robIdxValid = false

    // Not referenced anywhere in the visible part of this file.
    val useFlushDebug = false

    // Incremented once per testMultiple call; not read in the visible part of this file.
    var testCountsDebug = 0

    /**
     * Creates the module under test.
     *
     * @return a freshly constructed `Permutation`
     */
    // TestHarnessPerm.test_init(dut) was previously called here and is intentionally left out.
    override def getDut() : Module = new Permutation

    /**
     * Pairs a clock-cycle number with a register index.
     *
     * @param clk cycle at which the entry becomes due
     * @param idx register index associated with the entry
     */
    case class ClkIdx(
        clk : Int,
        idx : Int
    )
    /**
     * Pairs a clock-cycle number with an array of strings.
     *
     * @param clk cycle number
     * @param ary payload associated with that cycle
     */
    case class ClkAry(
        clk : Int,
        ary : Array[String]
    )

    /**
     * Draws from the shared seeded generator and reports whether the draw "fires".
     *
     * A draw fires when the value in [0, 100) is strictly greater than 50 and fewer than
     * `maxBlocks` draws have fired so far; each firing increments `blocks`.
     * One random number is consumed on every call, whether or not the budget is exhausted.
     *
     * @return true if this draw fired, false otherwise
     */
    def randomBool() : Boolean = {
        val roll = RandomGen.rand.nextInt(100)
        val fired = roll > 50 && blocks < maxBlocks
        if (fired) {
            blocks += 1
        }
        fired
    }

    /**
     * Compares the write-back values captured from the DUT with the expected values.
     *
     * The j-th captured value is compared with `goldenVd(n_inputs - 1 - j)`, i.e. the expected
     * array is walked in reverse. Every pair is logged through `Logger.printvds`; on the first
     * mismatch the test case is dumped and the assertion fails.
     *
     * @param goldenVd expected values; must hold at least `n_inputs` entries
     * @param vd       captured write-back values; must hold at least `n_inputs` entries
     * @param n_inputs number of values to compare
     * @param simi     the test case, passed through to `dump` on a mismatch
     */
    def resComp(goldenVd : Array[String], vd : Array[String], n_inputs : Int, simi : Map[String, String]) : Unit = {
        // Fail with a readable message rather than an ArrayIndexOutOfBoundsException when fewer
        // values are available than requested (e.g. the DUT produced too few write-backs).
        assert(vd.length >= n_inputs,
            s"expected at least ${n_inputs} write-back value(s) but only ${vd.length} were captured")
        assert(goldenVd.length >= n_inputs,
            s"expected at least ${n_inputs} golden value(s) but only ${goldenVd.length} were provided")

        for(j <- 0 until n_inputs) {
            val actual = vd(j)
            val expected = goldenVd(n_inputs - 1 - j)
            Logger.printvds(actual, expected)
            val vdres = expected.equals(actual)
            if (!vdres) dump(simi, actual, expected, fault_wb=(vdbase + j).toString)
            assert(vdres, s"write-back ${j} mismatch: got ${actual}, expected ${expected}")
        }
    }

    /**
     * Drives the DUT cycle by cycle after the uop has been issued.
     *
     * The loop ends when `perm_busy` is no longer 1, when the cycle counter reaches
     * `LOOP_MAX`, or when `randomFlush()` requests a flush (early return).
     *
     * In each cycle this method:
     *  - records a read request (`rd_en` / `rd_preg_idx`) to be answered `RD_DELAY` cycles later;
     *  - answers the oldest pending request if it is due, using `preg_to_value`;
     *  - pokes the input bundle and the redirect bundle;
     *  - captures `wb_data` if `wb_vld` is 1;
     *  - steps the clock.
     *
     * Side effect: `robIdxValid` is overwritten with the result of `randomFlush()` every cycle.
     *
     * @param dut           the Permutation module under test
     * @param preg_to_value map from a requested register index to the value returned as rdata
     * @param n_res         not used by this method
     * @param ctrlBundle    control bundle; a copy is poked every cycle
     * @param srcBundle     source bundle; a copy is poked every cycle
     * @return the captured write-back values in arrival order, each formatted as "h" followed
     *         by the value in hex, zero-padded to at least 32 digits
     */
    def stageTwo(
        dut : Permutation, preg_to_value : Map[Int, String], 
        n_res : Int, ctrlBundle : CtrlBundle, srcBundle : FSMSrcBundle
    ) : Array[String] = {
        // ====================================================================
        // stage 2: checking for FSM's read requests
        // ====================================================================
        // WB_DELAY only offsets the cycle stored in wb_idxs, which is used for the debug print
        // at the end; RD_DELAY is the number of cycles between a read request and its answer.
        val WB_DELAY = 3
        val RD_DELAY = 1
        val LOOP_MAX = 100

        var vd : BigInt = 0

        var clock_counter = 0
        var wb_idxs : Seq[ClkIdx] = Seq()
        var wb_counts = 0
        // Pending read requests in arrival order; rd_counts is the index of the next one to answer.
        var rd_idxs : Seq[ClkIdx] = Seq()
        var rd_counts = 0

        // Cycle at which the loop breaks; starts at the timeout and is pulled in when busy drops.
        var stop_iter = LOOP_MAX
        var done = false

        var res_vds : Array[String] = Array()
        var randomblock = false
        breakable{ while(true) {

            var rd_en = dut.io.out.rd_en.peek().litValue.toInt
            var rd_preg_idx = dut.io.out.rd_preg_idx.peek().litValue.toInt
            
            
            if(rd_en == 1) {
                rd_idxs :+= ClkIdx(clock_counter + RD_DELAY, rd_preg_idx)
            }
            

            // Default input for this cycle: no read data. Copies are used so the caller's
            // bundles are not modified.
            var fsmSrcBundle = srcBundle.copy()
            fsmSrcBundle.rdata="h0"
            fsmSrcBundle.rvalid=false
            var fsmCtrl = ctrlBundle.copy()
            
            // Stage 2.1: see if values are written back ============================
            /*if(wb_counts < wb_idxs.length && wb_idxs(wb_counts).clk == clock_counter) {
                vd = dut.io.out.wb_data.peek().litValue
                res_vds :+= f"h$vd%032x"
                println("res: ", f"h$vd%032x")
                wb_counts += 1
            }*/

            // Stage 2.2: if busy is down ============================================
            // The first time busy is seen low, stop_iter becomes the current cycle, so the
            // check below breaks out in this same iteration, before anything is poked.
            if(!done && dut.io.out.perm_busy.peek().litValue.toInt != 1) {
                stop_iter = clock_counter + 0 //.. if there's any delay
                done = true
            }
            if(clock_counter == stop_iter) {
                break
            }

            // Stage 2.3: if now is the time to send data to the FSM.. =============
            if(rd_counts < rd_idxs.length && rd_idxs(rd_counts).clk == clock_counter) {
                val rd_idx = rd_idxs(rd_counts).idx

                println(s"rd_idx: ${rd_idx}")
                // Map.apply: an index that was never registered raises NoSuchElementException.
                fsmSrcBundle.rdata=preg_to_value(rd_idx)
                fsmSrcBundle.rvalid=true
                    
                rd_counts += 1
            }

            // ================================================
            // 10.27 add random flush
            // robIdx is always (true, 1); whether a flush happens is decided by randomFlush(),
            // whose result is stored in the class-level robIdxValid and poked as the redirect.
            var robIdx = (true, 1)
            robIdxValid = randomFlush()
            /*if (robIdxValid) {
                robIdx = (true, 1)
            }*/

            fsmCtrl.robIdx = robIdx

            dut.io.in.poke(genFSMInput(
                fsmSrcBundle,
                fsmCtrl
            ))
            dut.io.redirect.poke(genFSMRedirect((robIdxValid, robIdxValid, 0)))

            // Stage 2.4: see if any wb value is there to be written =============
            val fsm_wb_vld = dut.io.out.wb_vld.peek().litValue.toInt
            if (fsm_wb_vld == 1) {
                wb_idxs :+= ClkIdx(clock_counter + WB_DELAY, 0)

                // 12.5: wb_data comes at the same cycle with wb_vld
                vd = dut.io.out.wb_data.peek().litValue
                res_vds :+= f"h$vd%032x"
                println("res: ", f"h$vd%032x")
                wb_counts += 1
            }

            dut.clock.step(1)
            clock_counter += 1

            if (robIdxValid) {
                // flushed
                println("flushed")

                // Drive idle inputs and clear the redirect, then wait for busy to drop before
                // returning whatever was captured so far.
                fsmSrcBundle.uop_valid = false
                fsmSrcBundle.rdata="h0"
                fsmSrcBundle.rvalid=false
                fsmCtrl = ctrlBundle.copy()

                // turning off redirect bits
                dut.io.in.poke(genFSMInput(
                    fsmSrcBundle,
                    fsmCtrl
                ))
                dut.io.redirect.poke(genFSMRedirect())
                
                // NOTE(review): this wait has no cycle limit; it relies on the DUT dropping busy.
                while(dut.io.out.perm_busy.peek().litValue.toInt == 1) {
                    dut.clock.step(1)
                    clock_counter += 1
                }
                dut.clock.step(1)
                return res_vds
            }
        }}

        // Reaching LOOP_MAX only prints a warning; the (possibly incomplete) results are still returned.
        if(clock_counter >= LOOP_MAX) println(s"!!!!!!!! Exceeds LOOP_MAX !!!!!!!! FSM has not done work after ${LOOP_MAX} cycles")

        // res_vds and wb_idxs are appended together, so they always have the same length.
        for(j <- 0 until res_vds.length) {
            println(f"res_vds($j)", res_vds(j))
            println(f"wb_idxs($j)", wb_idxs(j))
        }

        return res_vds
    }


    /**
     * Runs one test case through the DUT for every instruction except vrgatherei16.
     *
     * Steps: parse the test case, register every source value under an index derived from
     * `vs1base` / `vs2base` / `oldvdbase` / `maskbase`, issue the uop, let `stageTwo` serve
     * the DUT's read requests and collect write-backs, then compare against the expected
     * `VD` values unless a flush was injected.
     *
     * @param simi test case as key/value strings; "MASK", "RS1", "VS1" and "FS1" are optional,
     *             the other keys read below are mandatory
     * @param ctrl control bundle template; a copy with the per-case fields is sent to the DUT
     * @param s    not used by this method
     * @param dut  the Permutation module under test
     */
    def normalFSMTestMultiple(simi:Map[String,String],ctrl:CtrlBundle,s:String, dut:Permutation) : Unit = {
        // Mandatory-field accessor; a missing key fails on the Option's `.get`.
        def field(key : String) : String = simi.get(key).get

        val vs2data = UtilFuncs.multilmuldatahandle(field("VS2"))
        val oldvddata = UtilFuncs.multilmuldatahandle(field("OLD_VD"))
        // Without a MASK entry, a single all-ones mask value is used.
        val mask : Array[String] = simi.get("MASK") match {
            case Some(raw) => UtilFuncs.multilmuldatahandle(raw)
            case None      => Array("hffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff")
        }
        val vflmul = field("vflmul")
        // Parsed but not used further in this method.
        val vxsat = field("vxsat").toInt == 1
        val expectvd = UtilFuncs.multilmuldatahandle(field("VD"))
        val vsew = UtilFuncs.vsewconvert(field("vsew"))
        val vm = field("vm").toInt == 1
        val ma = field("ma").toInt == 1
        val ta = field("ta").toInt == 1
        val vl = field("vl").toInt
        val vxrm = field("vxrm").toInt
        val vstart = getVstart(simi)

        val hasRS1 = simi.get("RS1").isDefined
        val hasVS1 = simi.get("VS1").isDefined
        val hasFS1 = simi.get("FS1").isDefined
        val scalarSrc = hasRS1 || hasFS1

        // If several of RS1 / VS1 / FS1 are present, the last one in this order wins.
        var vs1data : Array[String] = Array()
        for ((key, present) <- Seq("RS1" -> hasRS1, "VS1" -> hasVS1, "FS1" -> hasFS1) if present) {
            vs1data = UtilFuncs.multilmuldatahandle(field(key))
        }
        if (hasFS1) {
            // Keep only characters 17..32 of the first FS1 string and re-add the "h" prefix.
            vs1data(0) = s"h${vs1data(0).slice(17, 33)}"
        }

        val n_inputs = vflmul match {
            case "2.000000" => 2
            case "4.000000" => 4
            case "8.000000" => 8
            case _          => 1
        }

        // ========================================================================================================================
        // stage 1: sending data to FSM
        // ========================================================================================================================
        val regSlots = 0 until n_inputs

        // A vs1 index exists for slot 0 when the operand is scalar (RS1/FS1), and for every
        // slot when it is a vector (VS1).
        val vs1_preg_idx : Seq[Int] =
            regSlots.filter(j => (j == 0 && scalarSrc) || hasVS1).map(vs1base + _)
        val vs2_preg_idx : Seq[Int] = regSlots.map(vs2base + _)
        val old_vd_preg_idx : Seq[Int] = regSlots.map(oldvdbase + _)

        // index map, from dut requested index to vs1/vs2/old_vd value
        // (slot j takes element n_inputs - 1 - j of each source array)
        val maskEntry = maskbase -> mask(0)
        val slotEntries = regSlots.flatMap { j =>
            val mirrored = n_inputs - 1 - j
            val vs1Entry : Option[(Int, String)] =
                if (j == 0 && scalarSrc) Some((vs1base + j) -> vs1data(0))
                else if (hasVS1) Some((vs1base + j) -> vs1data(mirrored))
                else None
            vs1Entry.toSeq ++ Seq(
                (vs2base + j) -> vs2data(mirrored),
                (oldvdbase + j) -> oldvddata(mirrored)
            )
        }
        val preg_to_value : Map[Int, String] = (Seq(maskEntry) ++ slotEntries).toMap

        val rs1value = if (scalarSrc) vs1data(0) else "h0"

        val srcBundle = FSMSrcBundle(
            rs1=rs1value,
            vs1_preg_idx=vs1_preg_idx,
            vs2_preg_idx=vs2_preg_idx,
            old_vd_preg_idx=old_vd_preg_idx,
            mask_preg_idx=maskbase,
            uop_valid=true
        )

        val ctrlBundle = ctrl.copy(
            vsew=vsew,
            vl=vl,
            vs1_imm=getImm(simi),
            vlmul = UtilFuncs.lmulconvert(vflmul).toInt,
            ma = ma,
            ta = ta,
            vm = vm,
            uopIdx=0,
            uopEnd=true,
            vxrm = vxrm,
            vstart = vstart,
            robIdx = (true, 1)
        )

        // Issue the uop with the redirect cleared and advance one cycle.
        dut.io.in.poke(genFSMInput(srcBundle, ctrlBundle))
        dut.io.redirect.poke(genFSMRedirect())
        dut.clock.step(1)

        val res_vds = stageTwo(dut, preg_to_value, n_inputs, ctrlBundle, srcBundle)

        // stageTwo leaves robIdxValid set when it flushed the instruction; skip the comparison then.
        if (robIdxValid) {
            println("robIdxValid = true, flush this instruction")
        } else {
            resComp(expectvd, res_vds, n_inputs, simi)
        }
    }

    /**
     * Runs one vrgatherei16 test case through the DUT.
     *
     * Differs from the generic flow in that the number of vs1 values (`vs1_n_inputs`) is
     * derived from "vsew" and may differ from the number of vs2 / old_vd values (`n_inputs`).
     * Index lists are padded by repetition so the shorter side lines up with the longer one.
     *
     * NOTE(review): with vsew 8 and vflmul "8.000000" this yields 16 vs1 indices starting at
     * `vs1base`, which would run into the range starting at `vs2base` — confirm such a case
     * never occurs in the test data.
     *
     * @param simi test case as key/value strings; "MASK" is optional, "VS1" and the other
     *             keys read below are mandatory
     * @param ctrl control bundle template; a copy with the per-case fields is sent to the DUT
     * @param s    not used by this method
     * @param dut  the Permutation module under test
     */
    def ei16FSMTestMultiple(simi:Map[String,String],ctrl:CtrlBundle,s:String, dut:Permutation) : Unit = {
        // Mandatory-field accessor; a missing key fails on the Option's `.get`.
        def field(key : String) : String = simi.get(key).get

        val vs2data = UtilFuncs.multilmuldatahandle(field("VS2"))
        val oldvddata = UtilFuncs.multilmuldatahandle(field("OLD_VD"))
        // Without a MASK entry, a single all-ones mask value is used.
        val mask : Array[String] = simi.get("MASK") match {
            case Some(raw) => UtilFuncs.multilmuldatahandle(raw)
            case None      => Array("hffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff")
        }
        val vflmul = field("vflmul")
        // Parsed but not used further in this method.
        val vxsat = field("vxsat").toInt == 1
        val expectvd = UtilFuncs.multilmuldatahandle(field("VD"))
        val vsew = UtilFuncs.vsewconvert(field("vsew"))
        val vm = field("vm").toInt == 1
        val ma = field("ma").toInt == 1
        val ta = field("ta").toInt == 1
        val vl = field("vl").toInt
        val vxrm = field("vxrm").toInt
        val vstart = getVstart(simi)

        val vs1data = UtilFuncs.multilmuldatahandle(field("VS1"))

        val n_inputs = vflmul match {
            case "2.000000" => 2
            case "4.000000" => 4
            case "8.000000" => 8
            case _          => 1
        }

        // Number of vs1 values relative to n_inputs. Any vsew other than 8/16/32/64 is a MatchError.
        val integerLmuls = Set("1.000000", "2.000000", "4.000000", "8.000000")
        val vs1_n_inputs = field("vsew").toInt match {
            case 8  => if (integerLmuls.contains(vflmul)) n_inputs * 2 else n_inputs
            case 16 => n_inputs
            case 32 => Math.max(n_inputs / 2, 1)
            case 64 => Math.max(n_inputs / 4, 1)
        }

        // =================================================================================================================
        // stage 1: sending data to FSM
        // =================================================================================================================
        val maskEntry = maskbase -> mask(0)

        // vs2 / old_vd: slot j takes element n_inputs - 1 - j of each source array.
        val vs2_preg_idx : Seq[Int] = (0 until n_inputs).map(vs2base + _)
        val vdSideEntries = (0 until n_inputs).flatMap { j =>
            val mirrored = n_inputs - 1 - j
            Seq(
                (vs2base + j) -> vs2data(mirrored),
                (oldvdbase + j) -> oldvddata(mirrored)
            )
        }
        // When there are more vs1 values than old_vd values (vsew=8), each old_vd index is repeated.
        val oldVdRepeat = if (vs1_n_inputs > n_inputs) vs1_n_inputs / n_inputs else 1
        val old_vd_preg_idx : Seq[Int] =
            (0 until n_inputs).flatMap(j => Seq.fill(oldVdRepeat)(oldvdbase + j))

        // vs1: slot j takes element vs1_n_inputs - 1 - j; indices are repeated when vs1 is the shorter side.
        val vs1Repeat = if (n_inputs > vs1_n_inputs) n_inputs / vs1_n_inputs else 1
        val vs1Entries = (0 until vs1_n_inputs).map { j =>
            val entry = (vs1base + j) -> vs1data(vs1_n_inputs - 1 - j)
            println(s"vsew ${field("vsew").toInt}, ${vs1_n_inputs}, ${n_inputs}")
            println(s"vs1base + j: ${vs1base + j}, ${vs1data(vs1_n_inputs - 1 - j)}")
            entry
        }
        val vs1_preg_idx : Seq[Int] =
            (0 until vs1_n_inputs).flatMap(j => Seq.fill(vs1Repeat)(vs1base + j))

        // Insertion order is mask, vs2/old_vd, then vs1, so on a duplicate key the vs1 value wins.
        val preg_to_value : Map[Int, String] = (Seq(maskEntry) ++ vdSideEntries ++ vs1Entries).toMap

        // ========================================================================================================================
        val srcBundle = FSMSrcBundle(
            rs1="h0",
            vs1_preg_idx=vs1_preg_idx,
            vs2_preg_idx=vs2_preg_idx,
            old_vd_preg_idx=old_vd_preg_idx,
            mask_preg_idx=maskbase,
            uop_valid=true
        )

        val ctrlBundle = ctrl.copy(
            vsew=vsew,
            vl=vl,
            vlmul = UtilFuncs.lmulconvert(vflmul).toInt,
            ma = ma,
            ta = ta,
            vm = vm,
            uopIdx=0,
            uopEnd=true,
            vxrm = vxrm,
            vstart = vstart,
            robIdx = (true, 1)
        )

        // Issue the uop with the redirect cleared and advance one cycle.
        dut.io.in.poke(genFSMInput(srcBundle, ctrlBundle))
        dut.io.redirect.poke(genFSMRedirect())
        dut.clock.step(1)

        val res_vds = stageTwo(dut, preg_to_value, n_inputs, ctrlBundle, srcBundle)

        // stageTwo leaves robIdxValid set when it flushed the instruction; skip the comparison then.
        if (robIdxValid) {
            println("robIdxValid = true, flush this instruction")
        } else {
            resComp(expectvd, res_vds, n_inputs, simi)
        }
    }

    /**
     * Entry point for one test case: resets the per-case state (`blocks`, `robIdxValid`),
     * counts the case in `testCountsDebug`, and runs the flow matching this behavior's
     * instruction id.
     *
     * @param simi test case as key/value strings
     * @param ctrl control bundle template
     * @param s    passed through to the selected flow
     * @param dut  the Permutation module under test
     */
    override def testMultiple(simi:Map[String,String],ctrl:CtrlBundle,s:String, dut:Permutation) : Unit = {
        // Per-case state shared with randomBool() and stageTwo().
        blocks = 0
        robIdxValid = false
        testCountsDebug += 1

        // vrgatherei16 needs its own operand layout; everything else shares the generic flow.
        val isEi16 = instid.equals("vrgatherei16_vv")
        if (isEi16) {
            ei16FSMTestMultiple(simi, ctrl, s, dut)
        } else {
            normalFSMTestMultiple(simi, ctrl, s, dut)
        }
    }

    /**
     * Runs a single test case; this is a plain delegation to `testMultiple`.
     *
     * @param simi test case as key/value strings
     * @param ctrl control bundle template
     * @param s    passed through to `testMultiple`
     * @param dut  the Permutation module under test
     */
    override def testSingle(simi:Map[String,String],ctrl:CtrlBundle,s:String, dut:Permutation) : Unit =
        testMultiple(simi, ctrl, s, dut)
}